/**
 * git clone https://github.com/hou-dao/deom.git
 * ---
 * Written by Houdao Zhang 
 * mailto: houdao@connect.ust.hk
 */
#include "deom.hpp"
#include <cmath>
#include <cstdio>
#include <sstream>
#include <stdexcept>
#include <string>


/**
 * Flatten a four-dimensional index (it0, it1, it2, it3) on an nt^4 grid into
 * a linear offset. it0 is the fastest-varying index and it3 the slowest:
 *
 *   offset = it0 + nt * (it1 + nt * (it2 + nt * it3))
 *
 * The arithmetic is carried out in 64 bits: the largest offset is nt^4 - 1,
 * which no longer fits in a 32-bit int once nt >= 216.
 *
 * @param nt   number of grid points along each of the four axes
 * @param it0  index along axis 0 (stride 1)
 * @param it1  index along axis 1 (stride nt)
 * @param it2  index along axis 2 (stride nt^2)
 * @param it3  index along axis 3 (stride nt^3)
 * @return     the linear offset
 */
static long long at(const int nt, const int it0, const int it1, const int it2, const int it3) {
  const long long n = nt; // widen once so every product below is 64-bit
  return n * (n * (n * it3 + it2) + it1) + it0;
}

void respPcc(const double tf, const int nt, const double dt,
             const cx_mat &rho0, const cube &dipo_mat, const string &dipo_lrc,
             const char &sch_hei, const syst &s, const bath &b, const hidx &h) {

  const int nt0 = nt;
  const int nt1 = nt;
  const int nt2 = nt;
  const int nt3 = nt;
  const double dt0 = tf / nt;
  const double dt1 = tf / nt;
  const double dt2 = tf / nt;
  const double dt3 = tf / nt;
  const int mt0 = floor(dt0 / dt);
  const int mt1 = floor(dt1 / dt);
  const int mt2 = floor(dt2 / dt);
  const int mt3 = floor(dt3 / dt);
  const double dt0_res = dt0 - dt * mt0;
  const double dt1_res = dt1 - dt * mt1;
  const double dt2_res = dt2 - dt * mt2;
  const double dt3_res = dt3 - dt * mt3;

  cx_vec ft = zeros<cx_vec>(nt0 * nt1 * nt2 * nt3);

  deom d0(s, b, h);
  cx_cube rho_t0 = zeros<cx_cube>(d0.nsys, d0.nsys, d0.nmax);
  rho_t0.slice(0) = rho0;

  if (sch_hei == 's') {                   // sch-pic
    for (int it0 = 0; it0 < nt0; ++it0) { // t_2' or t_2'+tau_2
      deom d1(d0);
      cx_cube rho_t1 = zeros<cx_cube>(d1.nsys, d1.nsys, d1.nmax);
      d1.oprAct(rho_t1, dipo_mat.slice(0), rho_t0, dipo_lrc[0]);
      for (int it1 = 0; it1 < nt1; ++it1) { // tau_2 or -tau_2
        deom d2(d1);
        cx_cube rho_t2 = zeros<cx_cube>(d2.nsys, d2.nsys, d2.nmax);
        d2.oprAct(rho_t2, dipo_mat.slice(1), rho_t1, dipo_lrc[1]);
        for (int it2 = 0; it2 < nt2; ++it2) {
          deom d3(d2);
          cx_cube rho_t3 = zeros<cx_cube>(d3.nsys, d3.nsys, d3.nmax);
          d3.oprAct(rho_t3, dipo_mat.slice(2), rho_t2, dipo_lrc[2]);
          for (int it3 = 0; it3 < nt3; ++it3) {
            ft(at(nt, it0, it1, it2, it3)) = trace(dipo_mat.slice(3) * rho_t3.slice(0));
            // printf("In sch-pic: it3=%d, nddo=%d, lddo=%d\n", it3, d3.nddo, d3.lddo);
            double t3 = it3 * dt3;
            for (int jt3 = 0; jt3 < mt3; ++jt3) {
              d3.rk4(rho_t3, t3, dt);
              t3 += dt;
            }
            d3.rk4(rho_t3, t3, dt3_res);
            t3 += dt3_res;
          }
          // printf("In sch-pic: it2=%d, nddo=%d, lddo=%d\n", it2, d2.nddo, d2.lddo);
          double t2 = it2 * dt2;
          for (int jt2 = 0; jt2 < mt2; ++jt2) {
            d2.rk4(rho_t2, t2, dt);
            t2 += dt;
          }
          d2.rk4(rho_t2, t2, dt2_res);
          t2 += dt2_res;
        }
        // printf("In sch-pic: it1=%d, nddo=%d, lddo=%d\n", it1, d1.nddo, d1.lddo);
        double t1 = it1 * dt1;
        for (int jt1 = 0; jt1 < mt1; ++jt1) {
          d1.rk4(rho_t1, t1, dt);
          t1 += dt;
        }
        d1.rk4(rho_t1, t1, dt1_res);
        t1 += dt1_res;
      }
      printf("In sch-pic: it0=%d, nddo=%d, lddo=%d\n", it0, d0.nddo, d0.lddo);
      double t0 = it0 * dt0;
      for (int jt0 = 0; jt0 < mt0; ++jt0) {
        d0.rk4(rho_t0, t0, dt);
        t0 += dt;
      }
      d0.rk4(rho_t0, t0, dt0_res);
      t0 += dt0_res;
    }
  } else if (sch_hei == 'h') { // hei-pic
    // Heisenberg picture dynamics
    deom d3(s, b, h);
    cx_cube opr_t3 = zeros<cx_cube>(d3.nsys, d3.nsys, d3.nmax);
    d3.iniHei(opr_t3, dipo_mat.slice(3));
    for (int it3 = 0; it3 < nt3; ++it3) {
      deom d2(d3);
      cx_cube opr_t2 = zeros<cx_cube>(d2.nsys, d2.nsys, d2.nmax);
      d2.oprAct(opr_t2, dipo_mat.slice(2), opr_t3, 'r'); // always act on the right
      for (int it2 = 0; it2 < nt2; ++it2) {
        field<ivec> keys(d2.nddo);
        for (int iddo = 0; iddo < d2.nddo; ++iddo) {
          keys(iddo) = d2.keys(iddo).key;
        }
        const cx_cube &oprs = opr_t2.head_slices(d2.nddo);
        stringstream ss1, ss2;
        ss1 << "key_t3_" << it3 << "_t2_" << it2 << ".tmp";
        ss2 << "opr_t3_" << it3 << "_t2_" << it2 << ".tmp";
        keys.save(ss1.str(), arma_binary);
        oprs.save(ss2.str(), arma_binary);
        //
        printf("In hei-pic: it2=%d, nddo=%d, lddo=%d\n", it2, d2.nddo, d2.lddo);
        double t2 = it2 * dt2;
        for (int jt2 = 0; jt2 < mt2; ++jt2) {
          d2.rk4(opr_t2, t2, dt, 'h');
          t2 += dt;
        }
        d2.rk4(opr_t2, t2, dt2_res, 'h');
      }
      printf("In hei-pic: it3=%d, nddo=%d, lddo=%d\n", it3, d3.nddo, d3.lddo);
      double t3 = it3 * dt3;
      for (int jt3 = 0; jt3 < mt3; ++jt3) {
        d3.rk4(opr_t3, t3, dt, 'h');
        t3 += dt;
      }
      d3.rk4(opr_t3, t3, dt3_res, 'h');
    }
    // Schrodinger picture dynamics
    for (int it0 = 0; it0 < nt0; ++it0) {
      deom d1(d0);
      cx_cube rho_t1 = zeros<cx_cube>(d1.nsys, d1.nsys, d1.nmax);
      d1.oprAct(rho_t1, dipo_mat.slice(0), rho_t0, dipo_lrc[0]);
      for (int it1 = 0; it1 < nt1; ++it1) {
        cx_cube rho_t2 = zeros<cx_cube>(d1.nsys, d1.nsys, d1.nddo);
        d1.oprAct(rho_t2, dipo_mat.slice(1), rho_t1, dipo_lrc[1]);
        for (int it2 = 0; it2 < nt2; ++it2) {
          for (int it3 = 0; it3 < nt3; ++it3) {
            stringstream ss1, ss2;
            field<ivec> keys;
            cx_cube oprs;
            ss1 << "key_t3_" << it3 << "_t2_" << it2 << ".tmp";
            ss2 << "opr_t3_" << it3 << "_t2_" << it2 << ".tmp";
            keys.load(ss1.str(), arma_binary);
            oprs.load(ss2.str(), arma_binary);
            const int nddo = keys.n_rows;
            cx_double ctmp = 0.0;
            for (int iado = 0; iado < nddo; ++iado) {
              TrieNode *p = d1.tree.find(keys(iado));
              if (p && p->rank >= 0) {
                const int jado = p->rank;
                ctmp += trace(oprs.slice(iado) * rho_t2.slice(jado));
              }
            }
            ft(at(nt, it0, it1, it2, it3)) = ctmp;
          }
        }
        // printf("In sch-pic: it1=%d, nddo=%d, lddo=%d\n", it1, d1.nddo, d1.lddo);
        double t1 = it1 * dt1;
        for (int jt1 = 0; jt1 < mt1; ++jt1) {
          d1.rk4(rho_t1, t1, dt);
          t1 += dt;
        }
        d1.rk4(rho_t1, t1, dt1_res);
        t1 += dt1_res;
      }
      printf("In sch-pic: it0=%d, nddo=%d, lddo=%d\n", it0, d0.nddo, d0.lddo);
      double t0 = it0 * dt0;
      for (int jt0 = 0; jt0 < mt0; ++jt0) {
        d0.rk4(rho_t0, t0, dt);
        t0 += dt;
      }
      d0.rk4(rho_t0, t0, dt0_res);
      t0 += dt0_res;
    }
    // clean tmp files
    for (int it3 = 0; it3 < nt3; ++it3) {
      for (int it2 = 0; it2 < nt2; ++it2) {
        stringstream ss1, ss2;
        ss1 << "key_t3_" << it3 << "_t2_" << it2 << ".tmp";
        ss2 << "opr_t3_" << it3 << "_t2_" << it2 << ".tmp";
        const string &s1 = ss1.str();
        const string &s2 = ss2.str();
        remove(s1.c_str());
        remove(s2.c_str());
      }
    }
  }
  // write to file
  ft.save("barePCC_2.mat", raw_ascii);
}
